package cn.zifangsky.tree.avltree;

import java.util.List;
import java.util.function.BiConsumer;

/**
 * AVL树的定义
 *
 * @author zifangsky
 * @date 2018/11/28
 * @since 1.0.0
 */
public class AvlTree<T extends Comparable<? super T>> {
    /**
     * AVL树中可被允许的某个节点的两棵子树的最大高度差
     */
    private static final int ALLOWED_HEIGHT_DIFFERENCE = 1;

    /**
     * AVL树的根节点
     */
    private AvlNode root;

    public AvlTree() {
        root = null;
    }

    /**
     * 插入某个数据
     * @author zifangsky
     * @date 2018/11/28 15:06
     * @since 1.0.0
     * @param data 待插入的数据
     */
    public void insert(T data){
        root = this.insert(root, data);
    }

    /**
     * 批量插入
     * @author zifangsky
     * @date 2018/11/29 14:15
     * @since 1.0.0
     * @param list 待插入的数据
     */
    public void insert(List<T> list){
        if(list != null && list.size() > 0){
            list.forEach(this::insert);
        }
    }

    /**
     * 批量插入
     * @author zifangsky
     * @date 2018/11/29 14:15
     * @since 1.0.0
     * @param values 待插入的数据
     */
    public void insert(T...values){
        if(values != null && values.length > 0){
            for(T data : values){
                this.insert(data);
            }
        }
    }

    /**
     * 删除某个数据
     * @author zifangsky
     * @date 2018/11/28 15:08
     * @since 1.0.0
     * @param data 待删除的数据
     */
    public void remove(T data){
        if(this.isEmpty()){
            throw new RuntimeException("当前AVL树为空，不能执行删除操作");
        }

        root = this.remove(root, data);
    }

    /**
     * 删除最小节点
     * @author zifangsky
     * @date 2018/11/29 11:41
     * @since 1.0.0
     */
    public void removeMin(){
        if(this.isEmpty()){
            throw new RuntimeException("当前AVL树为空，不能执行删除操作");
        }

        root = this.removeMin(root);
    }

    /**
     * 删除最大节点
     * @author zifangsky
     * @date 2018/11/29 11:41
     * @since 1.0.0
     */
    public void removeMax(){
        if(this.isEmpty()){
            throw new RuntimeException("当前AVL树为空，不能执行删除操作");
        }

        root = this.removeMax(root);
    }

    /**
     * 获取当前AVL树的最小元素
     * @author zifangsky
     * @date 2018/11/28 17:39
     * @since 1.0.0
     * @return T
     */
    public T findMin(){
        if(this.isEmpty()){
            throw new RuntimeException("当前AVL树为空，不能执行查找操作");
        }

        return this.findMin(root).element;
    }

    /**
     * 获取当前AVL树的最大元素
     * @author zifangsky
     * @date 2018/11/28 17:39
     * @since 1.0.0
     * @return T
     */
    public T findMax(){
        if(this.isEmpty()){
            throw new RuntimeException("当前AVL树为空，不能执行查找操作");
        }

        return this.findMax(root).element;
    }

    /**
     * 返回某个数据是否在AVL树中
     * @author zifangsky
     * @date 2018/11/28 15:10
     * @since 1.0.0
     * @param data 待检查的数据
     * @return boolean
     */
    public boolean contains(T data){
        return this.contains(root, data);
    }

    /**
     * 清空AVL树
     * @author zifangsky
     * @date 2018/11/28 15:11
     * @since 1.0.0
     */
    public void clear(){
        this.root = null;
    }

    /**
     * 判断AVL树是否为空
     * @author zifangsky
     * @date 2018/11/28 15:13
     * @since 1.0.0
     * @return boolean
     */
    public boolean isEmpty() {
        return this.root == null;
    }

    /**
     * 中序遍历AVL树
     * @author zifangsky
     * @date 2018/11/28 15:18
     * @since 1.0.0
     * @param consumer consumer
     */
    public void forEach(BiConsumer<T,Integer> consumer){
        this.forEach(root, consumer);
    }



    /**
     * 查找某个节点下面最小的节点
     * @param node 指定节点
     */
    private AvlNode findMin(AvlNode node){
        if(node == null){
            return null;
        }

        while (node.left != null){
            node = node.left;
        }

        return node;
    }

    /**
     * 查找某个节点下面最大的节点
     * @param node 指定节点
     */
    private AvlNode findMax(AvlNode node){
        if(node == null){
            return null;
        }

        while (node.right != null){
            node = node.right;
        }

        return node;
    }

    /**
     * 返回某个数据是否存在
     * @param node 指定节点
     * @param data 待检查的数据
     */
    private boolean contains(AvlNode node, T data){
        while (node != null){
            //比较待检查数据与当前节点谁大谁小
            int compareResult = data.compareTo(node.element);

            if(compareResult == 0){
                return true;
            }else if(compareResult < 0){
                node = node.left;
            }else{
                node = node.right;
            }
        }

        return false;
    }

    /**
     * 插入某个数据
     * @param node 在该节点及其子节点中插入
     * @param data 待插入的数据
     */
    private AvlNode insert(AvlNode node, T data){
        if(node == null){
            return new AvlNode(data);
        }

        //比较待插入数据与当前节点谁大谁小
        int compareResult = data.compareTo(node.element);

        if(compareResult < 0){
            //在节点node的左子树中插入
            node.left = this.insert(node.left, data);
        }else if(compareResult > 0){
            //在节点node的右子树中插入
            node.right = this.insert(node.right, data);
        }

        //插入节点后可能出现失衡情况，所以需要平衡节点关系
        return this.balance(node);
    }

    /**
     * 删除最小节点
     * @param node 在该节点及其子节点中查找并删除最小节点
     */
    private AvlNode removeMin(AvlNode node){
        if(node == null){
            return null;
        }

        if(node.left.left != null){
            node.left = this.removeMin(node.left);
        }else{
            node.left = null;
        }

        return this.balance(node);
    }

    /**
     * 删除最大节点
     * @param node 在该节点及其子节点中查找并删除最大节点
     */
    private AvlNode removeMax(AvlNode node){
        if(node == null){
            return null;
        }

        if(node.right.right != null){
            node.right = this.removeMax(node.right);
        }else{
            node.right = null;
        }

        return this.balance(node);
    }

    /**
     * 删除某个数据
     * @param node 在该节点及其子节点中查找并删除数据
     * @param data 待删除的数据
     */
    private AvlNode remove(AvlNode node, T data){
        if(node == null){
            return null;
        }

        //比较待删除数据与当前节点谁大谁小
        int compareResult = data.compareTo(node.element);

        //待删除的数据在左子树中
        if(compareResult < 0){
            node.left = this.remove(node.left, data);
        }
        //待删除的数据在右子树中
        else if(compareResult > 0){
            node.right = this.remove(node.right, data);
        }
        //待删除的数据在当前节点，且当前节点有两个孩子节点
        else if(node.left != null && node.right != null){
            /**
             * 这里有两种选择策略，要么取左子树中的最大值当根节点
             * 要么取右子树的最小值当根节点
             */
            //查找node节点左子树中的最大值
            node.element = this.findMax(node.left).element;
            //移除该节点
            this.remove(node.left, node.element);
        }
        //待删除的数据在当前节点，且当前节点只有一个孩子节点
        else{
            node = (node.left != null) ? node.left : node.right;
        }

        //移除节点后可能出现失衡情况，所以需要平衡节点关系
        return this.balance(node);
    }

    /**
     * 中序遍历AVL树
     * @param node 遍历该节点及其子节点
     * @param consumer consumer
     */
    private void forEach(AvlNode node, BiConsumer<T,Integer> consumer){
        if(node != null){
            this.forEach(node.left, consumer);
            //处理数据
            consumer.accept(node.element, node.height);

            this.forEach(node.right, consumer);
        }
    }

    /**
     * 平衡节点
     * @param node 待平衡的节点
     */
    private AvlNode balance(AvlNode node){
        if(node == null){
            return  null;
        }

        //左子树太深，需要进行平衡
        if(this.getNodeHeight(node.left) - this.getNodeHeight(node.right) > ALLOWED_HEIGHT_DIFFERENCE){
            //左单旋转
            if(this.getNodeHeight(node.left.left) >= this.getNodeHeight(node.left.right)){
                node = this.rotateWithLeftChild(node);
            }
            //左双旋转
            else{
                node = this.doubleWithLeftChild(node);
            }
        }
        //右子树太深，需要进行平衡
        else if(this.getNodeHeight(node.right) - this.getNodeHeight(node.left) > ALLOWED_HEIGHT_DIFFERENCE){
            //右单旋转
            if(this.getNodeHeight(node.right.right) >= this.getNodeHeight(node.right.left)){
                node = this.rotateWithRightChild(node);
            }
            //右双旋转
            else{
                node = this.doubleWithRightChild(node);
            }
        }

        //计算节点高度
        node.height = Math.max(this.getNodeHeight(node.left), this.getNodeHeight(node.right)) + 1;

        return node;
    }

    /**
     * 根据左孩子节点做单旋转
     * @param k2 待旋转子树的根节点
     */
    private AvlNode rotateWithLeftChild(AvlNode k2){
        //k2的左孩子节点
        AvlNode k1 = k2.left;
        //将k1的右孩子挂到k2的左孩子节点
        k2.left = k1.right;
        //将k2挂到k1的右孩子节点
        k1.right = k2;
        //重新计算k2的高度
        k2.height = Math.max(this.getNodeHeight(k2.left), this.getNodeHeight(k2.right)) + 1;
        //重新计算k1的高度
        k1.height = Math.max(this.getNodeHeight(k1.left), this.getNodeHeight(k2)) + 1;

        return k1;
    }

    /**
     * 根据右孩子节点做单旋转
     * @param k1 待旋转子树的根节点
     */
    private AvlNode rotateWithRightChild(AvlNode k1){
        //k1的右孩子节点
        AvlNode k2 = k1.right;
        //将k2的左孩子挂到k1的右孩子节点
        k1.right = k2.left;
        //将k1挂到k2的左孩子节点
        k2.left = k1;
        //重新计算k1的高度
        k1.height = Math.max(this.getNodeHeight(k1.left), this.getNodeHeight(k1.right)) + 1;
        //重新计算k2的高度
        k2.height = Math.max(this.getNodeHeight(k1), this.getNodeHeight(k2.right)) + 1;

        return k2;
    }

    /**
     * 根据左孩子节点做双旋转
     * @param k3 待旋转子树的根节点
     */
    private AvlNode doubleWithLeftChild(AvlNode k3){
        //先围绕k3的左孩子节点做一次右单旋转
        k3.left = this.rotateWithRightChild(k3.left);
        //再围绕k3做一次左单旋转
        return this.rotateWithLeftChild(k3);
    }

    /**
     * 根据右孩子节点做双旋转
     * @param k1 待旋转子树的根节点
     */
    private AvlNode doubleWithRightChild(AvlNode k1){
        //先围绕k1的右孩子节点做一次左单旋转
        k1.right = this.rotateWithLeftChild(k1.right);
        //再围绕k1做一次右单旋转
        return this.rotateWithRightChild(k1);
    }

    /**
     * 返回指定节点的高度
     * @param node 在该节点及其子节点中插入
     */
    private int getNodeHeight(AvlNode node){
        return node == null ? -1 : node.height;
    }

    /**
     * 定义AVL树的单个节点
     */
    private class AvlNode {
        /**
         * 数据
         */
        T element;
        /**
         * 左孩子节点
         */
        AvlNode left;
        /**
         * 右孩子节点
         */
        AvlNode right;
        /**
         * 当前节点的高度，这里规定叶子节点的高度为 0
         */
        int height;

        AvlNode(T element) {
            this(element, null, null);
        }

        AvlNode(T element, AvlNode left, AvlNode right) {
            this.element = element;
            this.left = left;
            this.right = right;
            //初始高度为0
            this.height = 0;
        }

    }

}
